import logging
import os
import pickle
import random
import shutil
import subprocess

import numpy as np
import SharedArray
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn

from .time_utils import *


def check_numpy_to_torch(x):
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).float(), True
    return x, False


def limit_period(val, offset=0.5, period=np.pi):
    val, is_numpy = check_numpy_to_torch(val)
    ans = val - torch.floor(val / period + offset) * period
    return ans.numpy() if is_numpy else ans
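

# Illustrative sketch (assumed input values): with the defaults offset=0.5 and period=pi,
# limit_period wraps an angle into [-pi/2, pi/2).
#   limit_period(np.array([3 * np.pi / 4]))   # -> approximately [-pi/4]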


def drop_info_with_name(info, name):
    """Drop the entries whose 'name' field equals `name` from every array in the info dict."""
    ret_info = {}
    keep_indices = [i for i, x in enumerate(info['name']) if x != name]
    for key in info.keys():
        ret_info[key] = info[key][keep_indices]
    return ret_info


def rotate_points_along_z(points, angle):
    """
    Args:
        points: (B, N, 3 + C)
        angle: (B), angle along z-axis, angle increases x ==> y
    Returns:

    """
    points, is_numpy = check_numpy_to_torch(points)
    angle, _ = check_numpy_to_torch(angle)

    cosa = torch.cos(angle)
    sina = torch.sin(angle)
    zeros = angle.new_zeros(points.shape[0])
    ones = angle.new_ones(points.shape[0])
    rot_matrix = torch.stack((
        cosa, sina, zeros,
        -sina, cosa, zeros,
        zeros, zeros, ones
    ), dim=1).view(-1, 3, 3).float()
    points_rot = torch.matmul(points[:, :, 0:3], rot_matrix)
    points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1)
    return points_rot.numpy() if is_numpy else points_rot
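

# Illustrative sketch (assumed input values): rotating the point (1, 0, 0) by pi/2 around z
# maps it onto the y-axis, matching "angle increases x ==> y".
#   pts = np.array([[[1.0, 0.0, 0.0]]])                 # (B=1, N=1, 3)
#   rotate_points_along_z(pts, np.array([np.pi / 2]))   # -> approximately [[[0, 1, 0]]]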


def mask_points_by_range(points, limit_range):
    # note: only the x/y extent of limit_range is used; z is not filtered here
    mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] <= limit_range[3]) \
           & (points[:, 1] >= limit_range[1]) & (points[:, 1] <= limit_range[4])
    return mask


def get_voxel_centers(voxel_coords, downsample_times, voxel_size, point_cloud_range):
    """
    Args:
        voxel_coords: (N, 3)
        downsample_times:
        voxel_size:
        point_cloud_range:

    Returns:

    """
    assert voxel_coords.shape[1] == 3
    voxel_centers = voxel_coords[:, [2, 1, 0]].float()  # (xyz)
    voxel_size = torch.tensor(voxel_size, device=voxel_centers.device).float() * downsample_times
    pc_range = torch.tensor(point_cloud_range[0:3], device=voxel_centers.device).float()
    voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range
    return voxel_centers
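

# Illustrative sketch (assumed voxel size and range): the voxel with (z, y, x) index (0, 0, 0)
# maps to the center of the first cell of the grid, returned in (x, y, z) order.
#   coords = torch.tensor([[0, 0, 0]])   # (z_idx, y_idx, x_idx)
#   get_voxel_centers(coords, downsample_times=1, voxel_size=[0.1, 0.1, 0.2],
#                     point_cloud_range=[0, -40, -3, 70.4, 40, 1])
#   # -> tensor([[0.0500, -39.9500, -2.9000]])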


def create_logger(log_file=None, rank=0, log_level=logging.INFO):
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level if rank == 0 else 'ERROR')
    formatter = logging.Formatter('%(asctime)s  %(levelname)5s  %(message)s')
    console = logging.StreamHandler()
    console.setLevel(log_level if rank == 0 else 'ERROR')
    console.setFormatter(formatter)
    logger.addHandler(console)
    if log_file is not None:
        file_handler = logging.FileHandler(filename=log_file)
        file_handler.setLevel(log_level if rank == 0 else 'ERROR')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    logger.propagate = False
    return logger


def get_pad_params(desired_size, cur_size):
    """
    Get padding parameters for np.pad function
    Args:
        desired_size: int, Desired padded output size
        cur_size: int, Current size. Should always be less than or equal to desired_size
    Returns:
        pad_params: tuple(int), Number of values padded to the edges (before, after)
    """
    assert desired_size >= cur_size

    # Calculate amount to pad
    diff = desired_size - cur_size
    pad_params = (0, diff)

    return pad_params
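

# Illustrative sketch (assuming an HWC image array): pad an image up to a target size with np.pad.
#   pad_h = get_pad_params(desired_size=384, cur_size=image.shape[0])
#   pad_w = get_pad_params(desired_size=1280, cur_size=image.shape[1])
#   image = np.pad(image, pad_width=(pad_h, pad_w, (0, 0)), mode='constant')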


def keep_arrays_by_name(gt_names, used_classes):
    """Return the indices (int64) of the entries in gt_names that are listed in used_classes."""
    inds = [i for i, x in enumerate(gt_names) if x in used_classes]
    inds = np.array(inds, dtype=np.int64)
    return inds


def init_dist_slurm(tcp_port, local_rank, backend='nccl'):
    """
    modified from https://github.com/open-mmlab/mmdetection
    Args:
        tcp_port:
        backend:

    Returns:

    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput('scontrol show hostname {} | head -n1'.format(node_list))
    os.environ['MASTER_PORT'] = str(tcp_port)
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)

    total_gpus = dist.get_world_size()
    rank = dist.get_rank()
    return total_gpus, rank


def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    # os.environ['MASTER_PORT'] = str(tcp_port)
    # os.environ['MASTER_ADDR'] = 'localhost'
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(local_rank % num_gpus)

    dist.init_process_group(
        backend=backend,
        # init_method='tcp://127.0.0.1:%d' % tcp_port,
        # rank=local_rank,
        # world_size=num_gpus
    )
    rank = dist.get_rank()
    return num_gpus, rank


def get_dist_info(return_gpu_per_machine=False):
    if torch.__version__ < '1.0':
        initialized = dist._initialized
    else:
        if dist.is_available():
            initialized = dist.is_initialized()
        else:
            initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    if return_gpu_per_machine:
        gpu_per_machine = torch.cuda.device_count()
        return rank, world_size, gpu_per_machine

    return rank, world_size


def merge_results_dist(result_part, size, tmpdir):
    rank, world_size = get_dist_info()
    os.makedirs(tmpdir, exist_ok=True)

    dist.barrier()
    with open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(rank)), 'wb') as f:
        pickle.dump(result_part, f)
    dist.barrier()

    if rank != 0:
        return None

    part_list = []
    for i in range(world_size):
        part_file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i))
        with open(part_file, 'rb') as f:
            part_list.append(pickle.load(f))

    ordered_results = []
    for res in zip(*part_list):
        ordered_results.extend(list(res))
    ordered_results = ordered_results[:size]
    shutil.rmtree(tmpdir)
    return ordered_results


def scatter_point_inds(indices, point_inds, shape):
    """Scatter point_inds into a dense tensor of the given shape, addressed by rows of indices; empty cells are -1."""
    ret = -1 * torch.ones(*shape, dtype=point_inds.dtype, device=point_inds.device)
    ndim = indices.shape[-1]
    flattened_indices = indices.view(-1, ndim)
    slices = tuple(flattened_indices[:, i] for i in range(ndim))
    ret[slices] = point_inds
    return ret


def generate_voxel2pinds(sparse_tensor):
    device = sparse_tensor.indices.device
    batch_size = sparse_tensor.batch_size
    spatial_shape = sparse_tensor.spatial_shape
    indices = sparse_tensor.indices.long()
    point_indices = torch.arange(indices.shape[0], device=device, dtype=torch.int32)
    output_shape = [batch_size] + list(spatial_shape)
    v2pinds_tensor = scatter_point_inds(indices, point_indices, output_shape)
    return v2pinds_tensor
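

# Illustrative sketch (assuming a spconv-style SparseConvTensor whose indices are
# (batch_idx, z_idx, y_idx, x_idx)): the result is a dense (batch_size, *spatial_shape)
# lookup table mapping every occupied voxel coordinate to its row in sparse_tensor.indices,
# with -1 for empty voxels.
#   v2p = generate_voxel2pinds(sp_tensor)
#   row = v2p[b_idx, z_idx, y_idx, x_idx]   # row into sp_tensor.features, or -1 if empty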


def sa_create(name, var):
    """Create a read-only SharedArray named `name` holding a copy of `var`."""
    x = SharedArray.create(name, var.shape, dtype=var.dtype)
    x[...] = var[...]
    x.flags.writeable = False
    return x


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
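

# Illustrative sketch (assumed training-loop variables `loss` and `batch_size`):
# track a running loss average.
#   loss_meter = AverageMeter()
#   loss_meter.update(loss.item(), n=batch_size)   # once per iteration
#   print(loss_meter.avg)                          # mean over all samples seen so far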


def make_fc_layers(fc_cfg, input_channels, output_channels=None):
    fc_layers = []
    c_in = input_channels
    for k in range(len(fc_cfg)):
        fc_layers.extend([
            nn.Linear(c_in, fc_cfg[k], bias=False),
            nn.BatchNorm1d(fc_cfg[k]),
            nn.ReLU(),
        ])
        c_in = fc_cfg[k]
    if output_channels is not None:
        fc_layers.append(nn.Linear(c_in, output_channels, bias=True))
    return nn.Sequential(*fc_layers)
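

# Illustrative sketch (assumed channel sizes): build a small prediction head.
#   head = make_fc_layers(fc_cfg=[256, 128], input_channels=512, output_channels=3)
#   # -> Linear(512, 256) + BN + ReLU, Linear(256, 128) + BN + ReLU, Linear(128, 3)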


class TestManager:
    _registered_module = dict()

    @staticmethod
    def __register(module, module_name, verbose):
        assert module_name not in TestManager._registered_module, f'{module_name} was already registered'
        TestManager._registered_module[module_name] = [module, verbose]

    @staticmethod
    def add(name=None, verbose=False):
        def register_without_specific_name(module, module_name=name):
            TestManager.__register(module, module.__name__ if name is None else name, verbose)
            return module

        return register_without_specific_name

    @staticmethod
    def run():
        for k, v in TestManager._registered_module.items():
            header = f"---------- [ {k} ] ----------"
            print(header, flush=True)
            with ScopeTimer("test case duration: ", v[1]):
                v[0]()
            # tail = "--" * (len(header) // 2)
            # print(tail, flush=True)
            print('')


def _torch_gather(x: torch.Tensor, ind: torch.Tensor, ind_bid: torch.Tensor = None,
                  channel_first: bool = False) -> torch.Tensor:
    assert x.is_contiguous()
    assert ind.is_contiguous()
    if ind_bid is None:  # batch gather
        (b, *s) = ind.shape
        if channel_first:
            ret = x.gather(-1, ind.view(b, 1, -1).expand(-1, x.shape[1], -1)).view(b, -1, *s)
        else:
            ret = x.gather(1, ind.view(b, -1, 1).expand(-1, -1, x.shape[-1])).view(b, *s, -1)

    else:  # stack gather from batched data
        assert ind_bid.is_contiguous()
        ind = ind.long()
        if channel_first:
            ret = x[ind_bid[..., None], :, ind].permute(0, 2, 1)
        else:
            n, c, s = x.shape[1], x.shape[2], ind.shape
            ret = x.view(-1, c)[(ind + (ind_bid * n)[..., None]).view(-1)].view(*s, -1)
        return ret
    return ret


def torch_gather(x: torch.Tensor, ind: torch.Tensor, ind_bid: torch.Tensor = None,
                 channel_first: bool = False) -> torch.Tensor:
    """

    Args:
        x: (B, N, C) or (B, C, N) if channel_first is True
        ind: (B, ...) or (M=N1+N2+..., K) if ind_bid is not None
        ind_bid: (M=N1+N2+...,), the batch index of each row in ind
        channel_first: indicates the channel order of x

    Returns:
        ind_bid is not None:
            channel_first is True: (M, C, K), non-contiguous
            channel_first is False: (M, K, C), contiguous
        ind_bid is None:
            channel_first is True: (B, C, ...), contiguous
            channel_first is False: (B, ..., C), contiguous
    """

    if torch.onnx.is_in_onnx_export():
        # class Gather(torch.autograd.Function):
        #     @staticmethod
        #     def forward(ctx, x: torch.tensor, ind: torch.tensor, ind_bid: torch.tensor = None,
        #                 channel_first: bool = False):
        #         return _torch_gather(x, ind, ind_bid, channel_first)
        #
        #     @staticmethod
        #     def symbolic(g, x: torch.tensor, ind: torch.tensor, ind_bid: torch.tensor = None,
        #                  channel_first: bool = False):
        #         return g.op('rd3d::Gather', x, ind)
        #
        # return Gather.apply(x, ind, ind_bid, channel_first)
        return _torch_gather(x, ind, ind_bid, channel_first)
    else:
        return _torch_gather(x, ind, ind_bid, channel_first)


gather = torch_gather
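

# Illustrative sketch (assumed shapes): batch gather of per-point features.
#   x = torch.randn(2, 100, 16)                    # (B, N, C)
#   ind = torch.randint(0, 100, (2, 32))           # (B, K) indices into N
#   torch_gather(x, ind).shape                     # -> torch.Size([2, 32, 16])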


def apply1d(mlps, x):
    if x.dim() > 2:
        b, *n, c = x.shape
        return mlps(x.view(-1, c)).view(b, *n, -1)
    else:
        return mlps(x)
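

# Illustrative sketch (assumed shapes): apply a 2-D MLP stack to a (B, N, C) tensor by
# flattening the leading dimensions; useful when the stack contains BatchNorm1d, which
# expects (N, C) input.
#   mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
#   apply1d(mlp, torch.randn(2, 100, 16)).shape    # -> torch.Size([2, 100, 32])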
